{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "77be7dda",
   "metadata": {},
   "source": [
    "## vllm-lmi rollingbatch Yi-34B-chat-4bits deployment guide\n",
    "\n",
    "### In this tutorial, you will use vllm backend of Large Model Inference(LMI) DLC to deploy Yi-34B-chat-4bits and run inference with it.\n",
    "\n",
    "Please make sure the following permission granted before running the notebook:\n",
    "\n",
    "* SageMaker access"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7bc171a0",
   "metadata": {},
   "source": [
    "### Step 1: Let's bump up SageMaker and import stuff"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bdd49db4",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "%pip install sagemaker --upgrade  --quiet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc3c2465",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import sagemaker\n",
    "from sagemaker.djl_inference.model import DJLModel\n",
    "\n",
    "role = sagemaker.get_execution_role()  # execution role for the endpoint\n",
    "session = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "86df559e",
   "metadata": {},
   "source": [
    "## Step 2: Start building SageMaker endpoint\n",
    "In this step, we will build SageMaker endpoint from scratch\n",
    "\n",
    "### Getting the container image URI (optional)\n",
    "\n",
    "Check out available images: [Large Model Inference available DLC](https://github.com/aws/deep-learning-containers/blob/master/available_images.md#large-model-inference-containers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79c21521",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Choose a specific version of LMI image directly:\n",
    "# image_uri = \"763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.28.0-lmi10.0.0-cu124\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5a817d0e",
   "metadata": {},
   "source": [
    "### Create SageMaker model\n",
    "\n",
    "Here we are using [LMI PySDK](https://sagemaker.readthedocs.io/en/stable/frameworks/djl/using_djl.html) to create the model.\n",
    "\n",
    "Checkout more [configuration options](https://docs.djl.ai/docs/serving/serving/docs/lmi/deployment_guide/configurations.html#environment-variable-configurations)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f903af58",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "model_id = \"01-ai/Yi-34B-Chat-4bits\" # model will be download form Huggingface hub\n",
    "\n",
    "env = {\n",
    "    \"TENSOR_PARALLEL_DEGREE\": \"4\",            # use 4 GPUs\n",
    "    \"OPTION_ROLLING_BATCH\": \"vllm\",           # use vllm for rolling batching\n",
    "    \"OPTION_TRUST_REMOTE_CODE\": \"true\",\n",
    "}\n",
    "\n",
    "model = DJLModel(\n",
    "            model_id=model_id,\n",
    "            env=env,\n",
    "            role=role)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c04c778f",
   "metadata": {},
   "source": [
    "### Create SageMaker endpoint\n",
    "\n",
    "You need to specify the instance to use and endpoint names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f27a8df",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "instance_type = \"ml.g4dn.12xlarge\"\n",
    "endpoint_name = sagemaker.utils.name_from_base(\"lmi-model\")\n",
    "\n",
    "predictor = model.deploy(initial_instance_count=1,\n",
    "             instance_type=instance_type,\n",
    "             endpoint_name=endpoint_name,\n",
    "             container_startup_health_check_timeout=1800\n",
    "            )"
   ]
  },
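  {
   "cell_type": "markdown",
   "id": "3c9a1e7b",
   "metadata": {},
   "source": [
    "As a rough sanity check on the instance choice, the sketch below estimates the memory taken by the quantized weights alone. It assumes roughly 34 billion parameters at 4 bits each, split evenly across the 4 GPUs implied by `TENSOR_PARALLEL_DEGREE`. It ignores quantization metadata, activations, and the KV cache, so real usage will be higher. The helper name `weight_memory_gib` is made up for this illustration.\n",
    "\n",
    "```python\n",
    "def weight_memory_gib(num_params, bits_per_param):\n",
    "    # Bytes needed for the weights, converted to GiB\n",
    "    return num_params * bits_per_param / 8 / 2**30\n",
    "\n",
    "\n",
    "num_params = 34e9           # assumption: parameter count rounded from the model name\n",
    "tensor_parallel_degree = 4  # matches TENSOR_PARALLEL_DEGREE in the env above\n",
    "\n",
    "total_gib = weight_memory_gib(num_params, bits_per_param=4)\n",
    "per_gpu_gib = total_gib / tensor_parallel_degree\n",
    "print(f'weights: ~{total_gib:.1f} GiB total, ~{per_gpu_gib:.1f} GiB per GPU')\n",
    "```\n",
    "\n",
    "Whatever GPU memory is left after loading the weights is what the serving engine can use for the KV cache, so a larger instance may be worth considering for long prompts or high concurrency."
   ]
  },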
  {
   "cell_type": "markdown",
   "id": "19004557",
   "metadata": {},
   "source": [
    "## Step 3: Test and benchmark the inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd192a22-3cda-4032-bdfd-0491e88cc068",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "input_text = \"<|im_start|>user\\n世界上第二高的山峰是哪座<|im_end|>\\n<|im_start|>assistant\\n\"\n",
    "\n",
    "parameters = {\n",
    "    \"max_new_tokens\":128,\n",
    "    \"do_sample\":True,\n",
    "    \"temperature\": 0.6,\n",
    "    \"eos_token_id\": 7,\n",
    "    \"top_p\": 0.8\n",
    "}"
   ]
  },
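  {
   "cell_type": "markdown",
   "id": "5d2f8a46",
   "metadata": {},
   "source": [
    "The prompt above is written by hand in a ChatML-style layout using the `<|im_start|>` and `<|im_end|>` markers. As a minimal sketch, a small helper can build the same layout from a list of messages. `build_chatml_prompt` is a hypothetical name used for illustration and is not part of the SageMaker or LMI SDK.\n",
    "\n",
    "```python\n",
    "def build_chatml_prompt(messages):\n",
    "    # messages: list of dicts with 'role' and 'content' keys\n",
    "    parts = []\n",
    "    for message in messages:\n",
    "        role, content = message['role'], message['content']\n",
    "        # Wrap each turn in <|im_start|>{role} ... <|im_end|>\n",
    "        parts.append(f'<|im_start|>{role}\\n{content}<|im_end|>\\n')\n",
    "    # Leave the assistant turn open so the model writes the reply\n",
    "    parts.append('<|im_start|>assistant\\n')\n",
    "    return ''.join(parts)\n",
    "\n",
    "\n",
    "prompt = build_chatml_prompt([{'role': 'user', 'content': 'What is the capital of France?'}])\n",
    "print(prompt)\n",
    "```\n",
    "\n",
    "The resulting string can be passed as `inputs` in the same way as `input_text`."
   ]
  },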
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fda8f387-680f-458b-85df-d5dee01174ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "predictor.predict(\n",
    "    {\"inputs\": input_text, \"parameters\": {\"max_new_tokens\":128, \"do_sample\":True}}\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f04cbcd1-3841-44da-b9fd-4b96e4c72783",
   "metadata": {},
   "source": [
    "#### Streaming"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd17ba98-ffd1-49ed-95d9-10840c2fdae4",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import json\n",
    "import boto3\n",
    "from utils.LineIterator import LineIterator\n",
    "\n",
    "smr_client = boto3.client(\"sagemaker-runtime\")\n",
    "def get_realtime_response_stream(sagemaker_runtime, endpoint_name, payload):\n",
    "    response_stream = sagemaker_runtime.invoke_endpoint_with_response_stream(\n",
    "        EndpointName=endpoint_name,\n",
    "        Body=json.dumps(payload), \n",
    "        ContentType=\"application/json\",\n",
    "        CustomAttributes='accept_eula=false'\n",
    "    )\n",
    "    return response_stream\n",
    "\n",
    "\n",
    "def print_response_stream(response_stream):\n",
    "    event_stream = response_stream.get('Body')\n",
    "    for line in LineIterator(event_stream):\n",
    "        print(line, end='')"
   ]
  },
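  {
   "cell_type": "markdown",
   "id": "7e61b0c3",
   "metadata": {},
   "source": [
    "`LineIterator` is imported from a local `utils` package, and its implementation is not shown in this notebook. As a minimal sketch of the idea, assuming each event in the response stream carries raw bytes under `PayloadPart` and `Bytes` and that records are separated by newlines, a line iterator can buffer incoming chunks and yield only complete lines. `iter_lines` is a hypothetical name and is not the helper used above.\n",
    "\n",
    "```python\n",
    "def iter_lines(event_stream):\n",
    "    # Accumulate byte chunks and yield one complete line at a time\n",
    "    buffer = b''\n",
    "    for event in event_stream:\n",
    "        chunk = event.get('PayloadPart', {}).get('Bytes', b'')\n",
    "        buffer += chunk\n",
    "        while b'\\n' in buffer:\n",
    "            line, buffer = buffer.split(b'\\n', 1)\n",
    "            yield line\n",
    "    # Flush any trailing data that did not end with a newline\n",
    "    if buffer:\n",
    "        yield buffer\n",
    "\n",
    "\n",
    "# Simulated events: the first line is split across two chunks\n",
    "events = [\n",
    "    {'PayloadPart': {'Bytes': b'first li'}},\n",
    "    {'PayloadPart': {'Bytes': b'ne\\nsecond line\\n'}},\n",
    "]\n",
    "lines = list(iter_lines(events))  # [b'first line', b'second line']\n",
    "```\n",
    "\n",
    "Buffering matters because a chunk boundary can fall in the middle of a record, so decoding each chunk on its own could fail."
   ]
  },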
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69410da3-91e0-459e-886c-61c282700507",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "payload = {\n",
    "    \"inputs\":  input_text,\n",
    "    \"parameters\": parameters,\n",
    "    \"stream\": True ## <-- to have response stream.\n",
    "}\n",
    "response_stream = get_realtime_response_stream(smr_client, endpoint_name, payload)\n",
    "print_response_stream(response_stream)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fd8073fb-0880-4697-aedf-4711918cacbd",
   "metadata": {},
   "source": [
    "## Clean up the environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9603928-b375-4d87-8c29-63b3cd59306a",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "session.delete_endpoint(endpoint_name)\n",
    "session.delete_endpoint_config(endpoint_name)\n",
    "model.delete_model()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
